package it.uniroma2.tk;

import it.uniroma2.util.tree.Tree;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Vector;

import edu.berkeley.compbio.jlibsvm.kernel.KernelFunction;

/**
 * Computes the Route Tree Kernel between {@link Tree} objects.
 * <p>
 * For every node {@code x} of the first tree, every node {@code y} of the second tree and
 * every route length {@code i} from 1 up to the smaller of the two tree depths, the kernel
 * adds a term {@code delta(x, y, i)}:
 * <ul>
 * <li>{@code delta(x, y, 1)} is 1 when the two nodes match, 0 otherwise;</li>
 * <li>for {@code i > 1}, {@code delta(x, y, i)} is {@code lambda * delta(x, y, i - 1)} when the
 * nodes match and the child position reached at the {@code (i-1)}-th step upwards (the index
 * of the node's {@code (i-2)}-th ancestor among the children of its {@code (i-1)}-th ancestor,
 * the node itself being its 0-th ancestor) exists and is equal for both nodes; otherwise 0.</li>
 * </ul>
 * Unrolled, {@code delta(x, y, i)} is {@code lambda^(i-1)} when the nodes match and their
 * first {@code i - 1} upward child positions coincide, and 0 otherwise.
 * <p>
 * Two nodes match when their root labels are equal or, if {@link #useProductions} is set,
 * when their productions (root label plus the ordered labels of their children) are equal.
 * <p>
 * Configuration lives in the public static fields {@link #lambda} and {@link #useProductions}
 * and is therefore shared by all instances. All intermediate state of one kernel evaluation
 * is local to that call.
 *
 * @author Fabio Massimo Zanzotto, Lorenzo Dell'Arciprete
 */
public class RouteTreeKernel implements KernelFunction<Tree> {

	/** Decay factor applied once per additional route step. */
	public static double lambda = 1;
	/**
	 * When {@code true}, nodes match only if their productions match (see
	 * {@code productionCompare}); otherwise only root labels are compared.
	 */
	public static boolean useProductions = false;

	/**
	 * Computes the route tree kernel between two trees.
	 * <p>
	 * Calls {@code initializeParents()} on both trees before any parent link is read.
	 *
	 * @param a the first tree
	 * @param b the second tree
	 * @return the sum of {@code delta(x, y, i)} over all node pairs {@code (x, y)} and all
	 *         route lengths {@code i} from 1 to {@code min(maxDepth(a), maxDepth(b))}
	 */
	public static double value(Tree a, Tree b) {
		a.initializeParents();
		b.initializeParents();
		// Snapshot the shared decay so a single evaluation uses one consistent value.
		final double decay = lambda;

		List<Tree> aNodes = allNodes(a);
		List<Tree> bNodes = allNodes(b);
		int n = aNodes.size();
		int m = bNodes.size();
		// Longest route length considered; computed once rather than on every loop test.
		int maxLevel = Math.min(maxDepth(a), maxDepth(b));

		// Routes only ever need positions for levels 1 .. maxLevel-1.
		int[][] aRoutes = new int[n][];
		for (int p = 0; p < n; p++) {
			aRoutes[p] = route(aNodes.get(p), maxLevel - 1);
		}
		int[][] bRoutes = new int[m][];
		for (int q = 0; q < m; q++) {
			bRoutes[q] = route(bNodes.get(q), maxLevel - 1);
		}

		// Whether two nodes match does not depend on the route length, so decide it once.
		boolean[][] matches = new boolean[n][m];
		for (int p = 0; p < n; p++) {
			for (int q = 0; q < m; q++) {
				matches[p][q] = deltaCompare(aNodes.get(p), bNodes.get(q));
			}
		}

		// deltas[p][q] holds delta(aNodes[p], bNodes[q], i) for the current level i. Level i
		// only depends on level i-1 of the same pair, so the table is updated in place.
		double[][] deltas = new double[n][m];
		double sum = 0;
		// Terms are summed level by level and, within a level, pairwise in pre-order; keep
		// this order when refactoring, since floating-point addition is not associative.
		for (int i = 1; i <= maxLevel; i++) {
			for (int p = 0; p < n; p++) {
				for (int q = 0; q < m; q++) {
					double k = 0;
					if (matches[p][q]) {
						if (i == 1) {
							k = 1;
						} else if (samePosition(aRoutes[p], bRoutes[q], i - 1)) {
							k = decay * deltas[p][q];
						}
					}
					deltas[p][q] = k;
					sum += k;
				}
			}
		}
		return sum;
	}

	/**
	 * Returns the upward child positions of a node.
	 * <p>
	 * Element {@code j} is the index of the node's {@code j}-th ancestor (the node itself being
	 * the 0-th) among the children of its {@code (j+1)}-th ancestor. It is -1 when that
	 * ancestor does not exist, or when {@code indexOf} does not find the child. At most
	 * {@code limit} steps are walked.
	 *
	 * @param node  the node whose route is computed
	 * @param limit the number of upward steps to record (length of the returned array)
	 * @return an array of length {@code limit} with the child positions, padded with -1
	 */
	private static int[] route(Tree node, int limit) {
		int[] positions = new int[limit];
		Tree child = node;
		Tree parent = node.getParent();
		int step = 0;
		while (step < limit && parent != null) {
			// NOTE(review): indexOf is equals-based. If Tree overrides equals structurally,
			// structurally equal siblings all report the first matching index; confirm.
			positions[step++] = parent.getChildren().indexOf(child);
			child = parent;
			parent = parent.getParent();
		}
		while (step < limit) {
			positions[step++] = -1;
		}
		return positions;
	}

	/**
	 * Tells whether two routes have the same, defined child position at the given level.
	 *
	 * @param routeA route of the first node, as returned by {@code route}
	 * @param routeB route of the second node, as returned by {@code route}
	 * @param level  1-based number of upward steps; must not exceed either route's length
	 * @return {@code true} if both positions are non-negative and equal
	 */
	private static boolean samePosition(int[] routeA, int[] routeB, int level) {
		int posA = routeA[level - 1];
		int posB = routeB[level - 1];
		return posA >= 0 && posB >= 0 && posA == posB;
	}

	/**
	 * Node matching used by the kernel: production comparison when {@link #useProductions}
	 * is set, root-label equality otherwise.
	 */
	private static boolean deltaCompare(Tree a, Tree b) {
		if (useProductions) {
			return productionCompare(a, b);
		}
		return a.getRootLabel().equals(b.getRootLabel());
	}

	/**
	 * Compares the productions rooted at two nodes.
	 *
	 * @return {@code true} if both root labels are equal and both nodes have the same non-zero
	 *         number of children with pairwise equal root labels; nodes without children
	 *         never match
	 */
	private static boolean productionCompare(Tree a, Tree b) {
		if (!a.getRootLabel().equals(b.getRootLabel())) {
			return false;
		}
		int arity = a.getChildren().size();
		if (arity != b.getChildren().size() || arity == 0) {
			return false;
		}
		for (int i = 0; i < arity; i++) {
			if (!a.getChildren().get(i).getRootLabel().equals(b.getChildren().get(i).getRootLabel())) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Lists every node of a tree in pre-order (a node before its children, children
	 * left to right).
	 */
	private static List<Tree> allNodes(Tree root) {
		List<Tree> nodes = new ArrayList<Tree>();
		collectNodes(root, nodes);
		return nodes;
	}

	/** Appends {@code node} and, recursively, all its descendants to {@code out} in pre-order. */
	private static void collectNodes(Tree node, List<Tree> out) {
		out.add(node);
		for (Tree child : node.getChildren()) {
			collectNodes(child, out);
		}
	}

	/**
	 * Returns the number of nodes on the longest downward path from {@code root} to a leaf;
	 * a node without children has depth 1.
	 */
	private static int maxDepth(Tree root) {
		int deepest = 0;
		for (Tree child : root.getChildren()) {
			deepest = Math.max(deepest, maxDepth(child));
		}
		return deepest + 1;
	}

	/**
	 * {@link KernelFunction} entry point; delegates to {@link #value(Tree, Tree)}.
	 *
	 * @param arg0 the first tree
	 * @param arg1 the second tree
	 * @return the route tree kernel value of the two trees
	 */
	public double evaluate(Tree arg0, Tree arg1) {
		return value(arg0, arg1);
	}

}
